{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.3 nn的子类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.set_printoptions(edgeitems=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `nn.Sequential`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Linear(in_features=1, out_features=11, bias=True)\n",
       "  (1): Tanh()\n",
       "  (2): Linear(in_features=11, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_model = nn.Sequential(\n",
    "            nn.Linear(1, 11), # 11的选取是任意的\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(11, 1)) # 这里的11需匹配上一层的输出\n",
    "seq_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (hidden_linear): Linear(in_features=1, out_features=12, bias=True)\n",
       "  (hidden_activation): Tanh()\n",
       "  (output_linear): Linear(in_features=12, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import OrderedDict\n",
    "namedseq_model = nn.Sequential(OrderedDict([\n",
    "    ('hidden_linear', nn.Linear(1, 12)),\n",
    "    ('hidden_activation', nn.Tanh()),\n",
    "    ('output_linear', nn.Linear(12 , 1))\n",
    "]))\n",
    "namedseq_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SubclassModel(\n",
       "  (hidden_linear): Linear(in_features=1, out_features=13, bias=True)\n",
       "  (hidden_activation): Tanh()\n",
       "  (output_linear): Linear(in_features=13, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class SubclassModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden_linear = nn.Linear(1, 13)\n",
    "        self.hidden_activation = nn.Tanh()\n",
    "        self.output_linear = nn.Linear(13, 1)\n",
    "    def forward(self, input):\n",
    "        hidden_t = self.hidden_linear(input)\n",
    "        activated_t = self.hidden_activation(hidden_t)\n",
    "        output_t = self.output_linear(activated_t)\n",
    "        return output_t\n",
    "    \n",
    "subclass_model = SubclassModel()\n",
    "subclass_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seq\n",
      "0.weight              torch.Size([11, 1]) 11\n",
      "0.bias                torch.Size([11])    11\n",
      "2.weight              torch.Size([1, 11]) 11\n",
      "2.bias                torch.Size([1])     1\n",
      "\n",
      "namedseq\n",
      "hidden_linear.weight  torch.Size([12, 1]) 12\n",
      "hidden_linear.bias    torch.Size([12])    12\n",
      "output_linear.weight  torch.Size([1, 12]) 12\n",
      "output_linear.bias    torch.Size([1])     1\n",
      "\n",
      "subclass\n",
      "hidden_linear.weight  torch.Size([13, 1]) 13\n",
      "hidden_linear.bias    torch.Size([13])    13\n",
      "output_linear.weight  torch.Size([1, 13]) 13\n",
      "output_linear.bias    torch.Size([1])     1\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for type_str, model in [('seq', seq_model), ('namedseq', namedseq_model),\n",
    "     ('subclass', subclass_model)]:\n",
    "    print(type_str)\n",
    "    for name_str, param in model.named_parameters():\n",
    "        print(\"{:21} {:19} {}\".format(name_str, str(param.shape), param.numel())) \n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SubclassFunctionalModel(\n",
       "  (hidden_linear): Linear(in_features=1, out_features=14, bias=True)\n",
       "  (output_linear): Linear(in_features=14, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class SubclassFunctionalModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden_linear = nn.Linear(1, 14)\n",
    "        # 去掉了nn.Tanh()\n",
    "        self.output_linear = nn.Linear(14, 1)\n",
    "        \n",
    "    def forward(self, input):\n",
    "        hidden_t = self.hidden_linear(input)\n",
    "        activated_t = torch.tanh(hidden_t) # nn.Tanh对应的函数\n",
    "        output_t = self.output_linear(activated_t)\n",
    "        return output_t\n",
    "    \n",
    "func_model = SubclassFunctionalModel()\n",
    "func_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36]",
   "language": "python",
   "name": "conda-env-py36-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
